{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3362a434",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default\n",
      "Reusing dataset chn_senti_corp (/Users/lee/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)\n",
      "Loading cached processed dataset at /Users/lee/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85/cache-e4f30e09e5a06112.arrow\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(9192,\n",
       " '选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般')"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "#定义数据集\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, split):\n",
    "        dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)\n",
    "\n",
    "        def f(data):\n",
    "            return len(data['text']) > 30\n",
    "\n",
    "        self.dataset = dataset.filter(f)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        text = self.dataset[i]['text']\n",
    "\n",
    "        return text\n",
    "\n",
    "\n",
    "dataset = Dataset('train')\n",
    "\n",
    "len(dataset), dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e70a58c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import BertTokenizer\n",
    "\n",
    "#加载字典和分词工具\n",
    "token = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "\n",
    "token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e59695a4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "574\n",
      "[CLS] 只 住 了 一 个 晚 上 ， 房 间 很 小 ， 不 [MASK] 考 虑 到 只 是 一 个 三 星 级 的 宾 馆 [SEP]\n",
      "过\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 30]),\n",
       " torch.Size([16, 30]),\n",
       " torch.Size([16, 30]),\n",
       " torch.Size([16]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def collate_fn(data):\n",
    "    #编码\n",
    "    data = token.batch_encode_plus(batch_text_or_text_pairs=data,\n",
    "                                   truncation=True,\n",
    "                                   padding='max_length',\n",
    "                                   max_length=30,\n",
    "                                   return_tensors='pt',\n",
    "                                   return_length=True)\n",
    "\n",
    "    #input_ids:编码之后的数字\n",
    "    #attention_mask:是补零的位置是0,其他位置是1\n",
    "    input_ids = data['input_ids']\n",
    "    attention_mask = data['attention_mask']\n",
    "    token_type_ids = data['token_type_ids']\n",
    "\n",
    "    #把第15个词固定替换为mask\n",
    "    labels = input_ids[:, 15].reshape(-1).clone()\n",
    "    input_ids[:, 15] = token.get_vocab()[token.mask_token]\n",
    "\n",
    "    #print(data['length'], data['length'].max())\n",
    "\n",
    "    return input_ids, attention_mask, token_type_ids, labels\n",
    "\n",
    "\n",
    "#数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=16,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "for i, (input_ids, attention_mask, token_type_ids,\n",
    "        labels) in enumerate(loader):\n",
    "    break\n",
    "\n",
    "print(len(loader))\n",
    "print(token.decode(input_ids[0]))\n",
    "print(token.decode(labels[0]))\n",
    "input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f620d0e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 30, 768])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import BertModel\n",
    "\n",
    "#加载预训练模型\n",
    "pretrained = BertModel.from_pretrained('bert-base-chinese')\n",
    "\n",
    "#不训练,不需要计算梯度\n",
    "for param in pretrained.parameters():\n",
    "    param.requires_grad_(False)\n",
    "\n",
    "#模型试算\n",
    "out = pretrained(input_ids=input_ids,\n",
    "           attention_mask=attention_mask,\n",
    "           token_type_ids=token_type_ids)\n",
    "\n",
    "out.last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b273bf7c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 21128])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#定义下游任务模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.decoder = torch.nn.Linear(768, token.vocab_size, bias=False)\n",
    "        self.bias = torch.nn.Parameter(torch.zeros(token.vocab_size))\n",
    "        self.decoder.bias = self.bias\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
    "        with torch.no_grad():\n",
    "            out = pretrained(input_ids=input_ids,\n",
    "                             attention_mask=attention_mask,\n",
    "                             token_type_ids=token_type_ids)\n",
    "\n",
    "        out = self.decoder(out.last_hidden_state[:, 15])\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "model = Model()\n",
    "\n",
    "model(input_ids=input_ids,\n",
    "      attention_mask=attention_mask,\n",
    "      token_type_ids=token_type_ids).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1bd44a7c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0 9.984944343566895 0.0\n",
      "0 50 8.533296585083008 0.125\n",
      "0 100 6.3379387855529785 0.25\n",
      "0 150 5.006066799163818 0.375\n",
      "0 200 4.933525562286377 0.3125\n",
      "0 250 5.3085222244262695 0.25\n",
      "0 300 4.5820441246032715 0.25\n",
      "0 350 3.11991548538208 0.5625\n",
      "0 400 4.346089839935303 0.3125\n",
      "0 450 3.8530707359313965 0.375\n",
      "0 500 1.6403313875198364 0.8125\n",
      "0 550 2.572502851486206 0.5625\n",
      "1 0 1.9827029705047607 0.8125\n",
      "1 50 2.5521962642669678 0.5625\n",
      "1 100 2.748547077178955 0.5625\n",
      "1 150 1.0036667585372925 0.875\n",
      "1 200 2.2741544246673584 0.625\n",
      "1 250 1.6249862909317017 0.8125\n",
      "1 300 1.682465672492981 0.75\n",
      "1 350 1.7401515245437622 0.6875\n",
      "1 400 1.65713369846344 0.8125\n",
      "1 450 1.53067946434021 0.8125\n",
      "1 500 1.451479196548462 0.8125\n",
      "1 550 2.535844087600708 0.5625\n",
      "2 0 1.3117371797561646 0.75\n",
      "2 50 0.6940720081329346 0.9375\n",
      "2 100 1.2575973272323608 0.8125\n",
      "2 150 1.3493293523788452 0.5625\n",
      "2 200 0.7944004535675049 0.9375\n",
      "2 250 1.33329176902771 0.625\n",
      "2 300 1.009347915649414 0.75\n",
      "2 350 1.3858131170272827 0.75\n",
      "2 400 0.5483282208442688 0.9375\n",
      "2 450 0.4145049750804901 0.9375\n",
      "2 500 1.3453038930892944 0.75\n",
      "2 550 0.672179102897644 0.875\n",
      "3 0 0.8804690837860107 0.875\n",
      "3 50 0.4713229537010193 1.0\n",
      "3 100 0.47764867544174194 1.0\n",
      "3 150 0.7397276759147644 0.875\n",
      "3 200 0.6919273138046265 0.8125\n",
      "3 250 0.5416679382324219 0.9375\n",
      "3 300 0.6753159761428833 0.75\n",
      "3 350 1.0199646949768066 0.8125\n",
      "3 400 0.6875881552696228 0.875\n",
      "3 450 0.4993954598903656 0.9375\n",
      "3 500 0.7261351346969604 0.875\n",
      "3 550 0.48293957114219666 0.875\n",
      "4 0 0.5171896815299988 0.875\n",
      "4 50 0.34932607412338257 1.0\n",
      "4 100 0.30208516120910645 1.0\n",
      "4 150 0.47497159242630005 0.9375\n",
      "4 200 0.3869628310203552 0.9375\n",
      "4 250 0.43719175457954407 1.0\n",
      "4 300 0.1504955291748047 1.0\n",
      "4 350 0.4130438566207886 0.8125\n",
      "4 400 0.2014932781457901 1.0\n",
      "4 450 0.5038660764694214 0.8125\n",
      "4 500 0.29774224758148193 1.0\n",
      "4 550 0.38373318314552307 0.875\n"
     ]
    }
   ],
   "source": [
    "from transformers import AdamW\n",
    "\n",
    "#训练\n",
    "optimizer = AdamW(model.parameters(), lr=5e-4)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "model.train()\n",
    "for epoch in range(5):\n",
    "    for i, (input_ids, attention_mask, token_type_ids,\n",
    "            labels) in enumerate(loader):\n",
    "        out = model(input_ids=input_ids,\n",
    "                    attention_mask=attention_mask,\n",
    "                    token_type_ids=token_type_ids)\n",
    "\n",
    "        loss = criterion(out, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        if i % 50 == 0:\n",
    "            out = out.argmax(dim=1)\n",
    "            accuracy = (out == labels).sum().item() / len(labels)\n",
    "\n",
    "            print(epoch, i, loss.item(), accuracy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "275dd1b6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default\n",
      "Reusing dataset chn_senti_corp (/Users/lee/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)\n",
      "Loading cached processed dataset at /Users/lee/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85/cache-3e9b343e1ee7b81d.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "[CLS] 某 些 酒 店 人 员 对 待 顾 客 不 诚 恳 。 [MASK] 给 你 换 房 间 ， 骗 你 说 是 价 钱 贵 [SEP]\n",
      "说 说\n",
      "1\n",
      "[CLS] 这 套 书 ， 是 听 别 人 介 绍 的 ， 回 家 [MASK] 3 岁 的 女 儿 一 起 分 享 了 故 事 内 [SEP]\n",
      "与 与\n",
      "2\n",
      "[CLS] flash 分 编 程 和 动 画 两 部 分 ， 这 本 书 [MASK] 的 是 flash [UNK] ， 可 是 脚 本 的 运 用 任 [SEP]\n",
      "说 说\n",
      "3\n",
      "[CLS] 作 者 力 从 马 克 思 注 意 经 济 学 角 度 [MASK] 剖 析 当 代 中 国 经 济 细 心 的 人 会 [SEP]\n",
      "来 来\n",
      "4\n",
      "[CLS] 唯 美 的 人 物 ， 唯 美 的 故 事 接 触 过 [MASK] 个 此 时 代 背 景 下 的 故 事 ， 不 过 [SEP]\n",
      "多 多\n",
      "5\n",
      "[CLS] 预 装 的 linux 不 是 直 接 进 入 系 统 的 ， [MASK] 方 便 测 试 机 器 。 随 机 光 盘 全 英 [SEP]\n",
      "不 不\n",
      "6\n",
      "[CLS] 选 择 的 事 例 太 离 奇 了 ， 夸 大 了 心 [MASK] 咨 询 的 现 实 意 义 ， 让 人 失 去 了 [SEP]\n",
      "理 理\n",
      "7\n",
      "[CLS] 洛 尔 卡 深 受 我 的 老 师 推 崇 。 在 那 [MASK] 神 奇 的 大 学 夜 晚 ， 跟 老 师 谈 论 [SEP]\n",
      "个 个\n",
      "8\n",
      "[CLS] 房 间 比 中 州 皇 冠 之 类 的 大 多 了 ， [MASK] 务 也 还 可 以 ； 早 餐 太 次 ， 品 种 [SEP]\n",
      "服 服\n",
      "9\n",
      "[CLS] 这 本 书 是 我 无 意 中 从 网 上 发 现 的 [MASK] 看 了 简 介 觉 得 不 错 就 继 续 把 整 [SEP]\n",
      "， ，\n",
      "10\n",
      "[CLS] 外 观 时 尚 ， 配 置 均 衡 ， 机 器 散 热 [MASK] 不 错 ， 用 了 大 概 一 个 月 ， 感 觉 [SEP]\n",
      "也 也\n",
      "11\n",
      "[CLS] 地 理 位 置 不 错 ， 房 间 环 境 也 可 以 [MASK] 洗 衣 速 度 快 ， 值 得 称 赞 ， 唯 一 [SEP]\n",
      "， ，\n",
      "12\n",
      "[CLS] 机 器 的 上 盖 外 表 面 是 钢 琴 烤 漆 的 [MASK] 名 副 其 实 的 指 纹 收 集 器 ， 看 来 [SEP]\n",
      "， ，\n",
      "13\n",
      "[CLS] 键 盘 很 生 硬 ， 没 有 手 感 。 标 配 内 [MASK] [UNK] 为 两 根 [UNK] 的 组 成 ， 很 郁 闷 ， [SEP]\n",
      "存 存\n",
      "14\n",
      "[CLS] 原 註 是 英 文 的, 為 何 不 可 留 下 英 [MASK] 呢, 有 十 有 英 不 是 更 好 嗎? 全 [SEP]\n",
      "文 文\n",
      "0.7041666666666667\n"
     ]
    }
   ],
   "source": [
    "#测试\n",
    "def test():\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),\n",
    "                                              batch_size=32,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    for i, (input_ids, attention_mask, token_type_ids,\n",
    "            labels) in enumerate(loader_test):\n",
    "\n",
    "        if i == 15:\n",
    "            break\n",
    "\n",
    "        print(i)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            out = model(input_ids=input_ids,\n",
    "                        attention_mask=attention_mask,\n",
    "                        token_type_ids=token_type_ids)\n",
    "\n",
    "        out = out.argmax(dim=1)\n",
    "        correct += (out == labels).sum().item()\n",
    "        total += len(labels)\n",
    "\n",
    "        print(token.decode(input_ids[0]))\n",
    "        print(token.decode(labels[0]), token.decode(labels[0]))\n",
    "\n",
    "    print(correct / total)\n",
    "\n",
    "\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
